{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rOvvWAVTkMR7"
      },
      "source": [
        "# Waste identification with instance segmentation in TensorFlow\n",
        "\n",
        "Welcome to the Instance Segmentation Colab! This notebook will take you through the steps of running an \"out-of-the-box\" Mask RCNN Instance Segmentation model on images."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HVTXSC07QwfG"
      },
      "source": [
        "Given 3 different Mask RCNN models for the material type, material form type and plastic type, your goal is to perform inference with any of the models and visualize the results."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AQUsAE0TRkmh"
      },
      "source": [
        "To finish this task, a proper path for the saved models and a single image needs to be provided. The path to the labels on which the models are trained is in the  waste_identification_ml directory inside the Tensorflow Model Garden repository. The label files are inferred automatically once you select the ML model by which you want to do the inference."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vPs64QA1Zdov"
      },
      "source": [
        "## Imports and Setup\n",
        "\n",
        "Let's start with the base imports."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OtfgxYR-oT8J"
      },
      "outputs": [],
      "source": [
        "# install model-garden official\n",
        "!pip install tf-models-official"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yn5_uV1HLvaz"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import pathlib\n",
        "import cv2\n",
        "import logging\n",
        "logging.disable(logging.WARNING)\n",
        "\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import numpy as np\n",
        "from six import BytesIO\n",
        "from PIL import Image\n",
        "from six.moves.urllib.request import urlopen\n",
        "\n",
        "from official.vision.ops.preprocess_ops import normalize_image\n",
        "\n",
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "14bNk1gzh0TN"
      },
      "source": [
        "## Visualization tools\n",
        "\n",
        "To visualize the images with the proper detected boxes and segmentation masks, we will use the TensorFlow Object Detection API. To install it we will clone the repo."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oi28cqGGFWnY"
      },
      "outputs": [],
      "source": [
        "# Clone the tensorflow models repository\n",
        "!git clone --depth 1 https://github.com/tensorflow/models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yX3pb_pXDjYA"
      },
      "source": [
        "Intalling the Object Detection API"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NwdsBdGhFanc"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "sudo apt install -y protobuf-compiler\n",
        "cd models/research/\n",
        "protoc object_detection/protos/*.proto --python_out=.\n",
        "cp object_detection/packages/tf2/setup.py .\n",
        "python -m pip install ."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3yDNgIx-kV7X"
      },
      "source": [
        "Now we can import the dependencies we will need later"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2JCeQU3fkayh"
      },
      "outputs": [],
      "source": [
        "from object_detection.utils import label_map_util\n",
        "from object_detection.utils import visualization_utils as viz_utils\n",
        "from object_detection.utils import ops as utils_ops\n",
        "\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XRUr9Aiwuho7"
      },
      "source": [
        "## Import pre-trained models from the Waste Identification project"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BWMh8UWl7eZA"
      },
      "outputs": [],
      "source": [
        "# download the model weights from the Google's repo\n",
        "!wget https://storage.googleapis.com/tf_model_garden/vision/waste_identification_ml/material_model.zip\n",
        "!wget https://storage.googleapis.com/tf_model_garden/vision/waste_identification_ml/material_form_model.zip\n",
        "!wget https://storage.googleapis.com/tf_model_garden/vision/waste_identification_ml/plastic_types_model.zip"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1oiSODmn7gh-"
      },
      "outputs": [],
      "source": [
        "# unziping the folders\n",
        "%%bash\n",
        "mkdir material material_form plastic_type\n",
        "unzip material_model.zip -d material/\n",
        "unzip material_form_model.zip -d material_form/\n",
        "unzip plastic_types_model.zip -d plastic_type/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ey-8Ij2sKjkD"
      },
      "outputs": [],
      "source": [
        "ALL_MODELS = {\n",
        "'material_model' : 'material/saved_model/saved_model/',\n",
        "'material_form_model' : 'material_form/saved_model/saved_model/',\n",
        "'plastic_model' : 'plastic_type/saved_model/saved_model/'\n",
        "}\n",
        "\n",
        "# path to an image\n",
        "IMAGES_FOR_TEST = {\n",
        "  'Image1' : 'models/official/projects/waste_identification_ml/pre_processing/config/sample_images/image_2.png'\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IogyryF2lFBL"
      },
      "source": [
        "## Utilities\n",
        "\n",
        "Run the following cell to create some utils that will be needed later:\n",
        "\n",
        "- Helper method to load an image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9XXfEdD9PMKn"
      },
      "outputs": [],
      "source": [
        "# Inputs to preprocess functions\n",
        "\n",
        "def load_image_into_numpy_array(path):\n",
        "  \"\"\"Load an image from file into a numpy array.\n",
        "\n",
        "  Puts image into numpy array to feed into tensorflow graph.\n",
        "  Note that by convention we put it into a numpy array with shape\n",
        "  (height, width, channels), where channels=3 for RGB.\n",
        "\n",
        "  Args:\n",
        "    path: the file path to the image\n",
        "\n",
        "  Returns:\n",
        "    uint8 numpy array with shape (1, h, w, 3)\n",
        "  \"\"\"\n",
        "  image = None\n",
        "  if(path.startswith('http')):\n",
        "    response = urlopen(path)\n",
        "    image_data = response.read()\n",
        "    image_data = BytesIO(image_data)\n",
        "    image = Image.open(image_data)\n",
        "  else:\n",
        "    image_data = tf.io.gfile.GFile(path, 'rb').read()\n",
        "    image = Image.open(BytesIO(image_data))\n",
        "\n",
        "  (im_width, im_height) = image.size\n",
        "  return np.array(image.getdata()).reshape(\n",
        "      (1, im_height, im_width, 3)).astype(np.uint8)\n",
        "\n",
        "\n",
        "def build_inputs_for_segmentation(image):\n",
        "  \"\"\"Builds segmentation model inputs for serving.\"\"\"\n",
        "  # Normalizes image with mean and std pixel values.\n",
        "  image = normalize_image(image)\n",
        "  return image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6917xnUSlp9x"
      },
      "source": [
        "## Build a instance segmentation model and load pre-trained model weights\n",
        "\n",
        "Here we will choose which Instance Segmentation model we will use.\n",
        "If you want to change the model to try other architectures later, just change the next cell and execute following ones."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HtwrSqvakTNn"
      },
      "outputs": [],
      "source": [
        "# @title Model Selection { display-mode: \"form\", run: \"auto\" }\n",
        "model_display_name = 'material_form_model' # @param ['material_model','material_form_model','plastic_model']\n",
        "model_handle = ALL_MODELS[model_display_name]\n",
        "\n",
        "print('Selected model:'+ model_display_name)\n",
        "print('Model Handle at TensorFlow Hub: {}'.format(model_handle))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NKtD0IeclbL5"
      },
      "source": [
        "### Load label map data (for plotting).\n",
        "\n",
        "Label maps correspond index numbers to category names, so that when our convolution network predicts `5`, we know that this corresponds to `airplane`.  Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine.\n",
        "\n",
        "We are going, for simplicity, to load from the repository that we loaded the Object Detection API code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3Kwqa0T1NTUf"
      },
      "outputs": [],
      "source": [
        "# @title Labels for the above model { display-mode: \"form\", run: \"auto\" }\n",
        "\n",
        "if model_display_name == 'material_model':\n",
        "  PATH_TO_LABELS = './models/official/projects/waste_identification_ml/pre_processing/config/data/material_labels.pbtxt'\n",
        "elif model_display_name == 'material_form_model':\n",
        "  PATH_TO_LABELS = './models/official/projects/waste_identification_ml/pre_processing/config/data/material_form_labels.pbtxt'\n",
        "elif model_display_name == 'plastic_model':\n",
        "  PATH_TO_LABELS = './models/official/projects/waste_identification_ml/pre_processing/config/data/plastic_type_labels.pbtxt'\n",
        "\n",
        "print('Labels selected for',model_display_name)\n",
        "print('\\n')\n",
        "category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)\n",
        "category_index"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "muhUt-wWL582"
      },
      "source": [
        "## Loading the selected model from TensorFlow Hub\n",
        "\n",
        "Here we just need the model handle that was selected and use the Tensorflow Hub library to load it to memory.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rBuD07fLlcEO"
      },
      "outputs": [],
      "source": [
        "print('loading model...')\n",
        "model = tf.saved_model.load(model_handle)\n",
        "print('model loaded!')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GIawRDKPPnd4"
      },
      "source": [
        "## Loading an image\n",
        "\n",
        "Let's try the model on a simple image.\n",
        "\n",
        "Here are some simple things to try out if you are curious:\n",
        "* Try running inference on your own images, just upload them to colab and load the same way it's done in the cell below.\n",
        "* Modify some of the input images and see if detection still works.  Some simple things to try out here include flipping the image horizontally, or converting to grayscale (note that we still expect the input image to have 3 channels).\n",
        "\n",
        "**Be careful:** when using images with an alpha channel, the model expect 3 channels images and the alpha will count as a 4th.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hX-AWUQ1wIEr"
      },
      "outputs": [],
      "source": [
        "#@title Image Selection (don't forget to execute the cell!) { display-mode: \"form\"}\n",
        "selected_image = 'Image1' # @param ['Image1']\n",
        "flip_image_horizontally = False #@param {type:\"boolean\"}\n",
        "convert_image_to_grayscale = False #@param {type:\"boolean\"}\n",
        "\n",
        "image_path = IMAGES_FOR_TEST[selected_image]\n",
        "image_np = load_image_into_numpy_array(image_path)\n",
        "\n",
        "# Flip horizontally\n",
        "if(flip_image_horizontally):\n",
        "  image_np[0] = np.fliplr(image_np[0]).copy()\n",
        "\n",
        "# Convert image to grayscale\n",
        "if(convert_image_to_grayscale):\n",
        "  image_np[0] = np.tile(\n",
        "    np.mean(image_np[0], 2, keepdims=True), (1, 1, 3)).astype(np.uint8)\n",
        "\n",
        "print('min:',np.min(image_np[0]), 'max:', np.max(image_np[0]))\n",
        "plt.figure(figsize=(24,32))\n",
        "plt.imshow(image_np[0])\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dkkBAgGcX65P"
      },
      "source": [
        "## Pre-processing an image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "97zIaKAhX-92"
      },
      "outputs": [],
      "source": [
        "# get an input size of images on which an Instance Segmentation model is trained\n",
        "detection_fn = model.signatures['serving_default']\n",
        "height= detection_fn.structured_input_signature[1]['inputs'].shape[1]\n",
        "width = detection_fn.structured_input_signature[1]['inputs'].shape[2]\n",
        "input_size = (height, width)\n",
        "print(input_size)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-K0V6KWiYYpD"
      },
      "outputs": [],
      "source": [
        "# apply pre-processing functions which were applied during training the model\n",
        "image_np_cp = cv2.resize(image_np[0], input_size[::-1], interpolation = cv2.INTER_AREA)\n",
        "image_np = build_inputs_for_segmentation(image_np_cp)\n",
        "image_np = tf.expand_dims(image_np, axis=0)\n",
        "image_np.get_shape()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ga1lccBpdxpd"
      },
      "outputs": [],
      "source": [
        "# display pre-processed image\n",
        "plt.figure(figsize=(24,32))\n",
        "plt.imshow(image_np[0])\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FTHsFjR6HNwb"
      },
      "source": [
        "## Doing the inference\n",
        "\n",
        "To do the inference we just need to call our TF Hub loaded model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gb_siXKcnnGC"
      },
      "outputs": [],
      "source": [
        "# running inference\n",
        "results = detection_fn(image_np)\n",
        "\n",
        "# different object detection models have additional results\n",
        "# all of them are explained in the documentation\n",
        "result = {key:value.numpy() for key,value in results.items()}\n",
        "print(result.keys())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IZ5VYaBoeeFM"
      },
      "source": [
        "## Visualizing the results\n",
        "\n",
        "Here is where we will need the TensorFlow Object Detection API to show the squares from the inference step (and the keypoints when available).\n",
        "\n",
        "the full documentation of this method can be seen [here](https://github.com/tensorflow/models/blob/master/research/object_detection/utils/visualization_utils.py)\n",
        "\n",
        "Here you can, for example, set `min_score_thresh` to other values (between 0 and 1) to allow more detections in or to filter out more detections."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PMzURFjxxqF7"
      },
      "outputs": [],
      "source": [
        "# selecting parameters for visualization\n",
        "label_id_offset = 0\n",
        "min_score_thresh =0.6\n",
        "use_normalized_coordinates=True\n",
        "\n",
        "if use_normalized_coordinates:\n",
        "  # Normalizing detection boxes\n",
        "  result['detection_boxes'][0][:,[0,2]] /= height\n",
        "  result['detection_boxes'][0][:,[1,3]] /= width"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FILNrrDy0kUg"
      },
      "outputs": [],
      "source": [
        "# Visualize detection and masks\n",
        "if 'detection_masks' in result:\n",
        "  # we need to convert np.arrays to tensors\n",
        "  detection_masks = tf.convert_to_tensor(result['detection_masks'][0])\n",
        "  detection_boxes = tf.convert_to_tensor(result['detection_boxes'][0])\n",
        "\n",
        "  detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(\n",
        "            detection_masks, detection_boxes,\n",
        "              image_np.shape[1], image_np.shape[2])\n",
        "  detection_masks_reframed = tf.cast(detection_masks_reframed \u003e 0.5,\n",
        "                                      np.uint8)\n",
        "\n",
        "  result['detection_masks_reframed'] = detection_masks_reframed.numpy()\n",
        "viz_utils.visualize_boxes_and_labels_on_image_array(\n",
        "      image_np_cp,\n",
        "      result['detection_boxes'][0],\n",
        "      (result['detection_classes'][0] + label_id_offset).astype(int),\n",
        "      result['detection_scores'][0],\n",
        "      category_index=category_index,\n",
        "      use_normalized_coordinates=use_normalized_coordinates,\n",
        "      max_boxes_to_draw=200,\n",
        "      min_score_thresh=min_score_thresh,\n",
        "      agnostic_mode=False,\n",
        "      instance_masks=result.get('detection_masks_reframed', None),\n",
        "      line_thickness=2)\n",
        "\n",
        "plt.figure(figsize=(24,32))\n",
        "plt.imshow(image_np_cp)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c75cSAeJ5JAQ"
      },
      "source": [
        "## Visualizing the masks only"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tt7RxYqhLpn9"
      },
      "outputs": [],
      "source": [
        "# collecting all masks and saving\n",
        "\n",
        "mask_count = np.sum(result['detection_scores'][0] \u003e= min_score_thresh)\n",
        "print('Total number of objects found are:', mask_count)\n",
        "mask = np.zeros_like(detection_masks_reframed[0])\n",
        "for i in range(mask_count):\n",
        "  if result['detection_scores'][0][i] \u003e= min_score_thresh:\n",
        "    mask += detection_masks_reframed[i]\n",
        "\n",
        "mask = tf.clip_by_value(mask, 0,1)\n",
        "plt.figure(figsize=(24,32))\n",
        "plt.imshow(mask,cmap='gray')\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "saved_model_inference.ipynb",
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
